/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.testing;



import com.bff.gaia.unified.sdk.transforms.Combine.CombineFn;

import org.hamcrest.Matcher;



import java.util.ArrayList;

import java.util.Arrays;

import java.util.Collections;

import java.util.List;



import static org.hamcrest.Matchers.is;

import static org.junit.Assert.assertThat;



/**

 * Utilities for testing {@link CombineFn CombineFns}. Ensures that the {@link CombineFn} gives

 * correct results across various permutations and shardings of the input.

 */

public class CombineFnTester {
  /**
   * Tests that the {@link CombineFn}, when applied to the provided input, produces the provided
   * output. Tests a variety of permutations of the input.
   *
   * <p>The input is checked once in the order supplied and once in a random order. The random
   * order is produced on a private copy, so {@code input} itself is never modified and may be an
   * unmodifiable list.
   *
   * @param fn the {@link CombineFn} under test
   * @param input the elements to combine
   * @param expected the value the extracted output must be equal to
   */
  public static <InputT, AccumT, OutputT> void testCombineFn(
      CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, final OutputT expected) {
    testCombineFn(fn, input, is(expected));
    // Shuffle a copy: shuffling the argument in place would reorder the caller's list and would
    // throw UnsupportedOperationException for lists that do not support set().
    List<InputT> shuffled = new ArrayList<>(input);
    Collections.shuffle(shuffled);
    testCombineFn(fn, shuffled, is(expected));
  }

  /**
   * Tests that the {@link CombineFn}, when applied to the provided input, produces an output
   * accepted by the provided matcher, across several shardings of the input.
   *
   * <p>The shardings exercised are: a single shard, two even shards, and (only for inputs with
   * more than four elements) {@code size / 2} and {@code (int) (size / sqrt(size))} even shards,
   * followed by exponentially shrinking shards with bases 1.4, 2 and {@code e}. Each sharding is
   * checked with every merge strategy, in its original shard order and in a random shard order.
   *
   * @param fn the {@link CombineFn} under test
   * @param input the elements to combine; it is only read, never modified
   * @param matcher the matcher the extracted output must satisfy
   */
  public static <InputT, AccumT, OutputT> void testCombineFn(
      CombineFn<InputT, AccumT, OutputT> fn, List<InputT> input, Matcher<? super OutputT> matcher) {
    int size = input.size();
    checkCombineFnShardsMultipleOrders(fn, Collections.singletonList(input), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, 2), matcher);
    if (size > 4) {
      checkCombineFnShardsMultipleOrders(fn, shardEvenly(input, size / 2), matcher);
      checkCombineFnShardsMultipleOrders(
          fn, shardEvenly(input, (int) (size / Math.sqrt(size))), matcher);
    }
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 1.4), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, 2), matcher);
    checkCombineFnShardsMultipleOrders(fn, shardExponentially(input, Math.E), matcher);
  }

  /**
   * Runs every merge strategy over {@code shards} in the given order, then again over a randomly
   * reordered copy of the shard list.
   *
   * @param fn the {@link CombineFn} under test
   * @param shards the sharded input; the list itself is not modified
   * @param matcher the matcher the extracted output must satisfy
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsMultipleOrders(
      CombineFn<InputT, AccumT, OutputT> fn,
      List<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    checkCombineFnShardsAllMergeStrategies(fn, shards, matcher);
    // Shuffle a copy so that immutable shard lists (e.g. Collections.singletonList) are never
    // asked to support set().
    List<Iterable<InputT>> shuffled = new ArrayList<Iterable<InputT>>(shards);
    Collections.shuffle(shuffled);
    checkCombineFnShardsAllMergeStrategies(fn, shuffled, matcher);
  }

  /** Applies the single-merge, empty-accumulator and incremental-merge checks, in that order. */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsAllMergeStrategies(
      CombineFn<InputT, AccumT, OutputT> fn,
      List<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    checkCombineFnShardsSingleMerge(fn, shards, matcher);
    checkCombineFnShardsWithEmptyAccumulators(fn, shards, matcher);
    checkCombineFnShardsIncrementalMerging(fn, shards, matcher);
  }

  /**
   * Builds one accumulator per shard, merges them all with a single {@code mergeAccumulators}
   * call and asserts that the extracted output satisfies {@code matcher}.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsSingleMerge(
      CombineFn<InputT, AccumT, OutputT> fn,
      Iterable<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    List<AccumT> accumulators = combineInputs(fn, shards);
    AccumT merged = fn.mergeAccumulators(accumulators);
    assertThat(fn.extractOutput(merged), matcher);
  }

  /**
   * Same as the single-merge check, but with a freshly created (input-free) accumulator placed
   * both before and after the per-shard accumulators prior to merging.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsWithEmptyAccumulators(
      CombineFn<InputT, AccumT, OutputT> fn,
      Iterable<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    List<AccumT> accumulators = combineInputs(fn, shards);
    accumulators.add(0, fn.createAccumulator());
    accumulators.add(fn.createAccumulator());
    AccumT merged = fn.mergeAccumulators(accumulators);
    assertThat(fn.extractOutput(merged), matcher);
  }

  /**
   * Merges the per-shard accumulators one at a time into a running accumulator, extracting (and
   * discarding) an output after every step, then asserts that the final extracted output
   * satisfies {@code matcher}.
   */
  private static <InputT, AccumT, OutputT> void checkCombineFnShardsIncrementalMerging(
      CombineFn<InputT, AccumT, OutputT> fn,
      List<? extends Iterable<InputT>> shards,
      Matcher<? super OutputT> matcher) {
    // With no shards the loop below never runs, so start from a fresh accumulator; otherwise the
    // first per-shard accumulator becomes the running one.
    AccumT accumulator = shards.isEmpty() ? fn.createAccumulator() : null;
    for (AccumT inputAccum : combineInputs(fn, shards)) {
      if (accumulator == null) {
        accumulator = inputAccum;
      } else {
        accumulator = fn.mergeAccumulators(Arrays.asList(accumulator, inputAccum));
      }
      fn.extractOutput(accumulator); // Extract output to simulate multiple firings
    }
    assertThat(fn.extractOutput(accumulator), matcher);
  }

  /**
   * Creates one accumulator per shard by adding each of the shard's elements in iteration order.
   * Every other accumulator (the 1st, 3rd, 5th, ...) is additionally passed through
   * {@code fn.compact}, so both compacted and non-compacted accumulators take part in merging.
   *
   * @return a new mutable list holding the accumulators in shard order
   */
  private static <InputT, AccumT, OutputT> List<AccumT> combineInputs(
      CombineFn<InputT, AccumT, OutputT> fn, Iterable<? extends Iterable<InputT>> shards) {
    List<AccumT> accumulators = new ArrayList<>();
    int maybeCompact = 0;
    for (Iterable<InputT> shard : shards) {
      AccumT accumulator = fn.createAccumulator();
      for (InputT elem : shard) {
        accumulator = fn.addInput(accumulator, elem);
      }
      if (maybeCompact++ % 2 == 0) {
        accumulator = fn.compact(accumulator);
      }
      accumulators.add(accumulator);
    }
    return accumulators;
  }

  /**
   * Splits {@code input} into {@code numShards} contiguous shards whose sizes differ by at most
   * one element. The shards are {@link List#subList} views of {@code input}, not copies.
   *
   * @return a new mutable list of the shards, in input order
   */
  private static <T> List<List<T>> shardEvenly(List<T> input, int numShards) {
    List<List<T>> shards = new ArrayList<>(numShards);
    // Compute the boundaries in long arithmetic: i * size overflows int for large inputs that are
    // split into many shards, which would produce negative subList indices.
    long size = input.size();
    for (int i = 0; i < numShards; i++) {
      int from = (int) (i * size / numShards);
      int to = (int) ((i + 1) * size / numShards);
      shards.add(input.subList(from, to));
    }
    return shards;
  }

  /**
   * Splits {@code input} into contiguous shards of geometrically shrinking size, working from the
   * tail of the list towards the head: the first shard is {@code [size / base, size)}, the next
   * one ends where the previous one started, and so on until index 0 is reached. An empty input
   * yields no shards. The shards are {@link List#subList} views of {@code input}, not copies.
   *
   * @param base the shrink factor; must be greater than 1.0
   * @return a new mutable list of the shards, ordered from the tail of the input to its head
   * @throws IllegalArgumentException if {@code base} is not greater than 1.0 (including NaN)
   */
  private static <T> List<List<T>> shardExponentially(List<T> input, double base) {
    // An explicit check rather than an assert: with assertions disabled, a base of exactly 1.0
    // would never shrink `end` and the loop below would not terminate.
    if (!(base > 1.0)) {
      throw new IllegalArgumentException("base must be greater than 1.0, got " + base);
    }
    List<List<T>> shards = new ArrayList<>();
    int end = input.size();
    while (end > 0) {
      int start = (int) (end / base);
      shards.add(input.subList(start, end));
      end = start;
    }
    return shards;
  }
}